"""Ablation study runner for MARL AutoML framework.
Each ablation is implemented WITHOUT modifying core framework code.
We wrap or monkey-patch lightweight behaviors at runtime.

Ablations:
  full: Full framework (baseline)
  no_teacher: Disable teacher (student acts alone)
  no_component_credit: Disable component-based credit assignment usage
  no_adaptive_exploration: Freeze epsilons (no decay / adaptive logic)
  no_pipeline_memory: Disable environment pipeline memory guidance

We achieve these by minimal wrappers.
"""
import copy
import json
from marl.train import marl_training
from marl.agents.teacher import TeacherAgent
from marl.agents.student import StudentAgent
from marl.utils.credit_assignment import CreditAssignment
from marl.environments.pipeline_env import PipelineEnvironment
from marl.environments.ml_components import COMPONENT_MAP
from experiments.utils import load_dataset_safely, seed_everything

# Lightweight helper to run training with hooks

def run_full(dataset: str, episodes: int, eval_timeout: int = 300):
    """Run the unmodified framework as the ablation baseline.

    Args:
        dataset: Forwarded to ``marl_training`` as ``dataset_name``.
        episodes: Forwarded to ``marl_training`` as ``episodes``.
        eval_timeout: Forwarded to ``marl_training`` as ``eval_timeout``.

    Returns:
        Whatever ``get_pipeline_statistics()`` reports on the object
        returned by ``marl_training``.
    """
    trained_env = marl_training(
        dataset_name=dataset,
        episodes=episodes,
        eval_timeout=eval_timeout,
    )
    statistics = trained_env.get_pipeline_statistics()
    return statistics


def run_no_teacher(dataset: str, episodes: int, eval_timeout: int = 300):
    """Train the student alone, with no teacher agent in the loop.

    Unlike the other ablations this one does not call ``marl_training``;
    it builds its own environment and student and lets the student pick
    every action.

    Args:
        dataset: Name passed to ``load_dataset_safely``.
        episodes: Number of training episodes to run.
        eval_timeout: Forwarded to ``PipelineEnvironment`` as ``eval_timeout``.

    Returns:
        The value of ``env.get_pipeline_statistics()`` after the last episode.

    Raises:
        RuntimeError: If ``load_dataset_safely`` returns no data; the message
            is the one reported by the loader.
    """
    seed_everything(42)
    data, msg = load_dataset_safely(dataset)
    if data is None:
        raise RuntimeError(msg)
    env = PipelineEnvironment(
        data,
        available_components=list(COMPONENT_MAP.keys()),
        max_pipeline_length=8,
        debug=False,
        eval_timeout=eval_timeout,
    )
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    student = StudentAgent(state_dim, action_dim, {'epsilon': 1.0, 'epsilon_min': 0.1})
    for _ in range(episodes):
        state = env.reset()
        done = False
        # Reset per episode: if no valid action exists on the very first step
        # the loop exits before env.step() runs, and without this the lookup
        # below would raise NameError (first episode) or silently reuse the
        # previous episode's info.
        info = {}
        while not done:
            valid = env.get_filtered_actions()
            if not valid:
                break
            action = student.act(state, valid, env=env)
            next_state, reward, done, info = env.step(action)
            student.learn(state, action, reward, next_state, done)
            state = next_state
        # Final reward comes from the last step's info; 0 when it carries no
        # 'performance' entry or when no step was taken this episode.
        perf = info.get('performance', 0)
        student.apply_final_reward(perf, decay=0.95)
    return env.get_pipeline_statistics()


def run_no_component_credit(dataset: str, episodes: int, eval_timeout: int = 300):
    """Run normal training with component credit assignment stubbed out.

    ``CreditAssignment.assign_component_credit`` is replaced at class level
    by a stand-in that returns an empty dict, and the original method is
    put back once training finishes or raises.

    Args:
        dataset: Forwarded to ``marl_training`` as ``dataset_name``.
        episodes: Forwarded to ``marl_training`` as ``episodes``.
        eval_timeout: Forwarded to ``marl_training`` as ``eval_timeout``.

    Returns:
        ``get_pipeline_statistics()`` of the object returned by
        ``marl_training``.
    """
    def _empty_credit(self, pipeline_components, performance, evaluate_fn):
        # Ignores every argument: no component receives any credit.
        return {}

    saved_method = CreditAssignment.assign_component_credit
    CreditAssignment.assign_component_credit = _empty_credit
    try:
        trained_env = marl_training(
            dataset_name=dataset,
            episodes=episodes,
            eval_timeout=eval_timeout,
        )
    finally:
        # Always undo the class-level patch so later runs see the real method.
        CreditAssignment.assign_component_credit = saved_method
    return trained_env.get_pipeline_statistics()


def run_no_adaptive_exploration(dataset: str, episodes: int, eval_timeout: int = 300):
    """Run normal training with exploration parameters frozen.

    Three class-level patches are installed for the duration of the run and
    removed afterwards, even if training raises:

    * ``DoubleDQN.update_model`` is wrapped so ``self.epsilon`` has the same
      value after the call as before it.
    * ``TeacherAgent._analyze_interventions`` becomes a no-op.
    * ``TeacherAgent.act`` is wrapped so ``self.intervention_decay`` is
      ``False`` while the original ``act`` executes.

    Args:
        dataset: Forwarded to ``marl_training`` as ``dataset_name``.
        episodes: Forwarded to ``marl_training`` as ``episodes``.
        eval_timeout: Forwarded to ``marl_training`` as ``eval_timeout``.

    Returns:
        ``get_pipeline_statistics()`` of the object returned by
        ``marl_training``.
    """
    from marl.models.double_dqn import DoubleDQN

    # Look up every original before patching anything: if one of these
    # attributes is missing, the AttributeError is raised while all classes
    # are still untouched, so no patch can be left behind.
    orig_update = DoubleDQN.update_model
    orig_analyze = TeacherAgent._analyze_interventions
    orig_teacher_act = TeacherAgent.act

    def frozen_update(self, *args, **kwargs):
        eps_before = self.epsilon
        result = orig_update(self, *args, **kwargs)
        # Undo whatever the update did to epsilon to freeze exploration.
        self.epsilon = eps_before
        # Hand back the original result so callers that use it are unaffected.
        return result

    def skip_analysis(self, *args, **kwargs):
        # Disables the teacher's intervention analysis entirely.
        return None

    def act_no_decay(self, state, valid_actions=None, student_action=None, env=None):
        # Force intervention_decay off for this call only, then put the
        # previous value back (True when the attribute was not set).
        saved = getattr(self, 'intervention_decay', True)
        self.intervention_decay = False
        try:
            return orig_teacher_act(self, state, valid_actions, student_action, env)
        finally:
            self.intervention_decay = saved

    try:
        DoubleDQN.update_model = frozen_update
        TeacherAgent._analyze_interventions = skip_analysis
        TeacherAgent.act = act_no_decay
        env = marl_training(dataset_name=dataset, episodes=episodes, eval_timeout=eval_timeout)
    finally:
        DoubleDQN.update_model = orig_update
        TeacherAgent._analyze_interventions = orig_analyze
        TeacherAgent.act = orig_teacher_act
    return env.get_pipeline_statistics()


def run_no_pipeline_memory(dataset: str, episodes: int, eval_timeout: int = 300):
    """Run normal training with the student cut off from pipeline memory.

    ``StudentAgent.act`` is replaced at class level by a wrapper that calls
    the original with ``teacher_feedback``, ``pipeline_memory`` and ``env``
    all set to ``None``; the original method is put back once training
    finishes or raises.

    Args:
        dataset: Forwarded to ``marl_training`` as ``dataset_name``.
        episodes: Forwarded to ``marl_training`` as ``episodes``.
        eval_timeout: Forwarded to ``marl_training`` as ``eval_timeout``.

    Returns:
        ``get_pipeline_statistics()`` of the object returned by
        ``marl_training``.
    """
    saved_act = StudentAgent.act

    def _act_without_memory(self, state, valid_actions=None, teacher_feedback=None,
                            pipeline_memory=None, env=None):
        # Whatever the caller supplied for the last three arguments is dropped.
        # NOTE(review): teacher_feedback and env are nulled as well, not only
        # pipeline_memory -- confirm this is intended for the ablation.
        return saved_act(
            self,
            state,
            valid_actions,
            teacher_feedback=None,
            pipeline_memory=None,
            env=None,
        )

    StudentAgent.act = _act_without_memory
    try:
        trained_env = marl_training(
            dataset_name=dataset,
            episodes=episodes,
            eval_timeout=eval_timeout,
        )
    finally:
        # Always undo the class-level patch.
        StudentAgent.act = saved_act
    return trained_env.get_pipeline_statistics()

# Registry mapping each ablation name to the function that runs it.
# Every runner is called as fn(dataset, episodes, eval_timeout).
# run_ablation_suite iterates this dict, so entries run in the order listed.
ABLATIONS = {
    'full': run_full,
    'no_teacher': run_no_teacher,
    'no_component_credit': run_no_component_credit,
    'no_adaptive_exploration': run_no_adaptive_exploration,
    'no_pipeline_memory': run_no_pipeline_memory,
}


def run_ablation_suite(dataset: str, episodes: int = 50, eval_timeout: int = 300):
    """Run every registered ablation and collect the outcome of each.

    A failing ablation does not stop the suite: its exception is caught and
    recorded, and the remaining ablations still run.

    Args:
        dataset: Passed to each runner as its first argument.
        episodes: Passed to each runner as its second argument.
        eval_timeout: Passed to each runner as its third argument.

    Returns:
        A dict keyed by ablation name. Each value is either what the runner
        returned or ``{"error": <exception message>}`` if it raised.
    """
    outcomes = {}
    for label, runner in ABLATIONS.items():
        print(f"\n=== Running ablation: {label} ===")
        try:
            stats = runner(dataset, episodes, eval_timeout)
        except Exception as exc:
            entry = {"error": str(exc)}
        else:
            entry = stats
        outcomes[label] = entry
    return outcomes

if __name__ == '__main__':
    # Command-line entry point: run the whole suite and print it as JSON.
    import argparse
    import json

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='iris')
    parser.add_argument('--episodes', type=int, default=50)
    parser.add_argument('--eval-timeout', type=int, default=300)
    cli = parser.parse_args()

    suite_results = run_ablation_suite(cli.dataset, cli.episodes, cli.eval_timeout)
    print(json.dumps(suite_results, indent=2))
